{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "foo\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from src.utils import get_wb, foo\n",
    "from src.utils import print_params_dict, print_params_number\n",
    "\n",
    "from tensorlayer.layers import DenseLayer, RNNLayer, InputLayer, FlattenLayer\n",
    "\n",
    "from src.layers import dense, multi_dense, linear_dense, Dense\n",
    "from src.layers import highway_dense, multi_highway_dense\n",
    "from src.layers import conv2d\n",
    "from src.layers import highway_conv2d, multi_highway_conv2d\n",
    "from src.layers import attention_for_rnn, attention_for_dense\n",
    "from src.layers import attention_flow_self, attention_flow\n",
    "\n",
    "foo()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## get_wb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8, 16)\n",
      "[-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.]\n",
      "W:0 \t{'number': 32, 'shape': [2, 16]}\n",
      "b:0 \t{'number': 16, 'shape': [16]}\n"
     ]
    }
   ],
   "source": [
    "# get_wb: build a weight/bias pair and apply it as a manual affine layer.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input shape: [8, 2]\n",
    "x = tf.constant(np.arange(16, dtype=np.float32).reshape([8, 2]))\n",
    "# W has shape [2, 16]; b has shape [16] and is initialised to -1.0.\n",
    "w, b = get_wb([2, 16], b_initializer=tf.initializers.constant(-1.0))\n",
    "# Output shape: [8, 16]. The op is built before the session so the session\n",
    "# block only initialises the variables and evaluates tensors.\n",
    "o = tf.matmul(x, w) + b\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    # Evaluate the affine output for real: a symbolic Tensor would only\n",
    "    # report its static shape and the matmul would never be executed.\n",
    "    o_val, b_val = sess.run([o, b])\n",
    "\n",
    "print(o_val.shape)\n",
    "print(b_val)\n",
    "print_params_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8, 16)\n",
      "D/W:0 \t{'number': 32, 'shape': [2, 16]}\n",
      "D/b:0 \t{'number': 16, 'shape': [16]}\n"
     ]
    }
   ],
   "source": [
    "# Dense: one fully connected layer whose variables live under scope \"D\".\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# [8, 2] float32 input holding the values 0..15.\n",
    "inputs = tf.constant(np.arange(16, dtype=np.float32).reshape([8, 2]))\n",
    "# 16 output units -> output shape [8, 16].\n",
    "outputs = dense(inputs, 16, name=\"D\")\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    out_val = sess.run(outputs)\n",
    "\n",
    "print(out_val.shape)\n",
    "print_params_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dense reuse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "D/W:0 \t{'number': 32, 'shape': [2, 16]}\n",
      "D/b:0 \t{'number': 16, 'shape': [16]}\n"
     ]
    }
   ],
   "source": [
    "# Dense reuse: a second call with reuse=True shares the variables of scope \"D\".\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# [8, 2] float32 input holding the values 0..15.\n",
    "inputs = tf.constant(np.arange(16, dtype=np.float32).reshape([8, 2]))\n",
    "# Both calls project to 16 units -> output shape [8, 16].\n",
    "first = dense(inputs, 16, name=\"D\")\n",
    "second = dense(inputs, 16, name=\"D\", reuse=True)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    first_val, second_val = sess.run([first, second])\n",
    "\n",
    "# Shared weights must produce element-wise identical outputs.\n",
    "outputs_match = (first_val == second_val).all()\n",
    "print(outputs_match)\n",
    "print_params_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dense reuse with class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "DD/W:0 \t{'number': 32, 'shape': [2, 16]}\n",
      "DD/b:0 \t{'number': 16, 'shape': [16]}\n"
     ]
    }
   ],
   "source": [
    "# Dense class: applying the same Dense instance twice reuses its variables.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# [8, 2] float32 input holding the values 0..15.\n",
    "inputs = tf.constant(np.arange(16, dtype=np.float32).reshape([8, 2]))\n",
    "# One layer object with 16 units, called twice -> two [8, 16] outputs.\n",
    "shared_dense = Dense(16, name=\"DD\")\n",
    "first = shared_dense(inputs)\n",
    "second = shared_dense(inputs)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    first_val, second_val = sess.run([first, second])\n",
    "\n",
    "# Only one W/b pair exists, so both applications must agree exactly.\n",
    "outputs_match = (first_val == second_val).all()\n",
    "print(outputs_match)\n",
    "print_params_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi Dense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8, 5)\n",
      "D-0/W:0 \t{'number': 6, 'shape': [2, 3]}\n",
      "D-0/b:0 \t{'number': 3, 'shape': [3]}\n",
      "D-1/W:0 \t{'number': 12, 'shape': [3, 4]}\n",
      "D-1/b:0 \t{'number': 4, 'shape': [4]}\n",
      "D-2/W:0 \t{'number': 20, 'shape': [4, 5]}\n",
      "D-2/b:0 \t{'number': 5, 'shape': [5]}\n"
     ]
    }
   ],
   "source": [
    "# Multi Dense: a stack of dense layers with 3, 4 and 5 units.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# [8, 2] float32 input holding the values 0..15.\n",
    "inputs = tf.constant(np.arange(16, dtype=np.float32).reshape([8, 2]))\n",
    "# The last layer has 5 units -> output shape [8, 5].\n",
    "layer_units = [3, 4, 5]\n",
    "outputs = multi_dense(inputs, layer_units, name=\"D\")\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    out_val = sess.run(outputs)\n",
    "\n",
    "print(out_val.shape)\n",
    "print_params_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Highway Dense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8, 2)\n",
      "H/W:0 \t{'number': 4, 'shape': [2, 2]}\n",
      "H/b:0 \t{'number': 2, 'shape': [2]}\n",
      "H/transform/W:0 \t{'number': 4, 'shape': [2, 2]}\n",
      "H/transform/b:0 \t{'number': 2, 'shape': [2]}\n"
     ]
    }
   ],
   "source": [
    "# Highway Dense: a single highway layer under scope \"H\".\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# [8, 2] float32 input holding the values 0..15.\n",
    "inputs = tf.constant(np.arange(16, dtype=np.float32).reshape([8, 2]))\n",
    "# The highway layer keeps the input width -> output shape [8, 2].\n",
    "outputs = highway_dense(inputs, name=\"H\")\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    out_val = sess.run(outputs)\n",
    "\n",
    "print(out_val.shape)\n",
    "print_params_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi Highway Dense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8, 2)\n",
      "H1-0/W:0 \t{'number': 4, 'shape': [2, 2]}\n",
      "H1-0/b:0 \t{'number': 2, 'shape': [2]}\n",
      "H1-0/transform/W:0 \t{'number': 4, 'shape': [2, 2]}\n",
      "H1-0/transform/b:0 \t{'number': 2, 'shape': [2]}\n",
      "H1-1/W:0 \t{'number': 4, 'shape': [2, 2]}\n",
      "H1-1/b:0 \t{'number': 2, 'shape': [2]}\n",
      "H1-1/transform/W:0 \t{'number': 4, 'shape': [2, 2]}\n",
      "H1-1/transform/b:0 \t{'number': 2, 'shape': [2]}\n",
      "H1-2/W:0 \t{'number': 4, 'shape': [2, 2]}\n",
      "H1-2/b:0 \t{'number': 2, 'shape': [2]}\n",
      "H1-2/transform/W:0 \t{'number': 4, 'shape': [2, 2]}\n",
      "H1-2/transform/b:0 \t{'number': 2, 'shape': [2]}\n"
     ]
    }
   ],
   "source": [
    "# Multi Highway Dense: three stacked highway layers named from \"H1\".\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# [8, 2] float32 input holding the values 0..15.\n",
    "inputs = tf.constant(np.arange(16, dtype=np.float32).reshape([8, 2]))\n",
    "# The stack keeps the input width -> output shape [8, 2].\n",
    "n_layers = 3\n",
    "outputs = multi_highway_dense(inputs, n_layers, name=\"H1\")\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    out_val = sess.run(outputs)\n",
    "\n",
    "print(out_val.shape)\n",
    "print_params_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conv2d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 20, 20, 16)\n",
      "conv2d/W:0 \t{'number': 144, 'shape': [3, 3, 1, 16]}\n",
      "conv2d/b:0 \t{'number': 16, 'shape': [16]}\n"
     ]
    }
   ],
   "source": [
    "# Conv2d: one convolution called with (3, 16) under the default scope name.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# [4, 20, 20, 1] float32 input holding the values 0..1599.\n",
    "images = tf.constant(np.arange(1600, dtype=np.float32).reshape([4, 20, 20, 1]))\n",
    "# Output shape: [4, 20, 20, 16]\n",
    "feature_maps = conv2d(images, 3, 16)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    out_val = sess.run(feature_maps)\n",
    "\n",
    "print(out_val.shape)\n",
    "print_params_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Highway Conv2d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 20, 20, 1)\n",
      "highway_conv2d/W:0 \t{'number': 9, 'shape': [3, 3, 1, 1]}\n",
      "highway_conv2d/b:0 \t{'number': 1, 'shape': [1]}\n",
      "highway_conv2d/transform/W:0 \t{'number': 9, 'shape': [3, 3, 1, 1]}\n",
      "highway_conv2d/transform/b:0 \t{'number': 1, 'shape': [1]}\n"
     ]
    }
   ],
   "source": [
    "# Highway Conv2d: a single highway convolution block called with kernel size 3.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input shape: [4, 20, 20, 1]\n",
    "x = tf.constant(np.arange(1600, dtype=np.float32).reshape([4, 20, 20, 1]))\n",
    "# Output shape: [4, 20, 20, 1] -- no channel count is passed and the output\n",
    "# keeps the single input channel (it is not [4, 20, 20, 16]).\n",
    "o = highway_conv2d(x, 3)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    o_val = sess.run(o)\n",
    "\n",
    "print(o_val.shape)\n",
    "print_params_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi Highway Conv2d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 20, 20, 1)\n",
      "HCnn1/W:0 \t{'number': 9, 'shape': [3, 3, 1, 1]}\n",
      "HCnn1/b:0 \t{'number': 1, 'shape': [1]}\n",
      "HCnn1/transform/W:0 \t{'number': 9, 'shape': [3, 3, 1, 1]}\n",
      "HCnn1/transform/b:0 \t{'number': 1, 'shape': [1]}\n",
      "HCnn2/W:0 \t{'number': 16, 'shape': [4, 4, 1, 1]}\n",
      "HCnn2/b:0 \t{'number': 1, 'shape': [1]}\n",
      "HCnn2/transform/W:0 \t{'number': 16, 'shape': [4, 4, 1, 1]}\n",
      "HCnn2/transform/b:0 \t{'number': 1, 'shape': [1]}\n",
      "HCnn3/W:0 \t{'number': 25, 'shape': [5, 5, 1, 1]}\n",
      "HCnn3/b:0 \t{'number': 1, 'shape': [1]}\n",
      "HCnn3/transform/W:0 \t{'number': 25, 'shape': [5, 5, 1, 1]}\n",
      "HCnn3/transform/b:0 \t{'number': 1, 'shape': [1]}\n",
      "106\n"
     ]
    }
   ],
   "source": [
    "# Multi Highway Conv2d: three highway conv blocks chained by hand.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input shape: [4, 20, 20, 1] (the -1 batch axis resolves to 1600 / 400 = 4).\n",
    "images = tf.constant(np.arange(1600, dtype=np.float32).reshape([-1, 20, 20, 1]))\n",
    "# Blocks called with 3, 4 and 5; final output shape: [4, 20, 20, 1].\n",
    "block1 = highway_conv2d(images, 3, name=\"HCnn1\")\n",
    "block2 = highway_conv2d(block1, 4, name=\"HCnn2\")\n",
    "block3 = highway_conv2d(block2, 5, name=\"HCnn3\")\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    out_val = sess.run(block3)\n",
    "\n",
    "print(out_val.shape)\n",
    "print_params_dict()\n",
    "print_params_number()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi Highway Conv2d with function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 20, 20, 1)\n",
      "(4, 20, 20, 1)\n",
      "highway_conv2d-0/W:0 \t{'number': 9, 'shape': [3, 3, 1, 1]}\n",
      "highway_conv2d-0/b:0 \t{'number': 1, 'shape': [1]}\n",
      "highway_conv2d-0/transform/W:0 \t{'number': 9, 'shape': [3, 3, 1, 1]}\n",
      "highway_conv2d-0/transform/b:0 \t{'number': 1, 'shape': [1]}\n",
      "highway_conv2d-1/W:0 \t{'number': 9, 'shape': [3, 3, 1, 1]}\n",
      "highway_conv2d-1/b:0 \t{'number': 1, 'shape': [1]}\n",
      "highway_conv2d-1/transform/W:0 \t{'number': 9, 'shape': [3, 3, 1, 1]}\n",
      "highway_conv2d-1/transform/b:0 \t{'number': 1, 'shape': [1]}\n",
      "HC-0/W:0 \t{'number': 9, 'shape': [3, 3, 1, 1]}\n",
      "HC-0/b:0 \t{'number': 1, 'shape': [1]}\n",
      "HC-0/transform/W:0 \t{'number': 9, 'shape': [3, 3, 1, 1]}\n",
      "HC-0/transform/b:0 \t{'number': 1, 'shape': [1]}\n",
      "HC-1/W:0 \t{'number': 16, 'shape': [4, 4, 1, 1]}\n",
      "HC-1/b:0 \t{'number': 1, 'shape': [1]}\n",
      "HC-1/transform/W:0 \t{'number': 16, 'shape': [4, 4, 1, 1]}\n",
      "HC-1/transform/b:0 \t{'number': 1, 'shape': [1]}\n",
      "HC-2/W:0 \t{'number': 25, 'shape': [5, 5, 1, 1]}\n",
      "HC-2/b:0 \t{'number': 1, 'shape': [1]}\n",
      "HC-2/transform/W:0 \t{'number': 25, 'shape': [5, 5, 1, 1]}\n",
      "HC-2/transform/b:0 \t{'number': 1, 'shape': [1]}\n",
      "146\n"
     ]
    }
   ],
   "source": [
    "# Multi Highway Conv2d built with the multi_highway_conv2d helper.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input shape: [4, 20, 20, 1] (the -1 batch axis resolves to 1600 / 400 = 4).\n",
    "images = tf.constant(np.arange(1600, dtype=np.float32).reshape([-1, 20, 20, 1]))\n",
    "# Two blocks sharing kernel size 3, created under the default name.\n",
    "default_stack = multi_highway_conv2d(images, 3, 2)\n",
    "# Three blocks with kernel sizes 3, 4 and 5, created under the name \"HC\".\n",
    "named_stack = multi_highway_conv2d(images, [3, 4, 5], 3, name=\"HC\")\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    default_val, named_val = sess.run([default_stack, named_stack])\n",
    "\n",
    "# Both stacks keep the shape [4, 20, 20, 1].\n",
    "print(default_val.shape)\n",
    "print(named_val.shape)\n",
    "print_params_dict()\n",
    "print_params_number()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention for Dense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 8)\n",
      "attention_for_dense/W:0 \t{'number': 64, 'shape': [8, 8]}\n",
      "attention_for_dense/b:0 \t{'number': 8, 'shape': [8]}\n",
      "72\n"
     ]
    }
   ],
   "source": [
    "# Attention for Dense: attention over the features of a 2-D input.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input shape: [4, 8] (the -1 batch axis resolves to 32 / 8 = 4).\n",
    "features = tf.constant(np.arange(32, dtype=np.float32).reshape([-1, 8]))\n",
    "# Output shape: [4, 8]\n",
    "attended = attention_for_dense(features)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    out_val = sess.run(attended)\n",
    "\n",
    "print(out_val.shape)\n",
    "print_params_dict()\n",
    "print_params_number()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention for RNN (use_mean_attention=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(128, 5, 32)\n",
      "attention_for_rnn/W:0 \t{'number': 25, 'shape': [5, 5]}\n",
      "attention_for_rnn/b:0 \t{'number': 5, 'shape': [5]}\n",
      "30\n"
     ]
    }
   ],
   "source": [
    "# Attention for RNN with the default arguments.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input shape: [128, 5, 32]\n",
    "sequence = tf.constant(np.arange(20480, dtype=np.float32).reshape([128, 5, 32]))\n",
    "# Output shape: [128, 5, 32]\n",
    "attended = attention_for_rnn(sequence)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    out_val = sess.run(attended)\n",
    "\n",
    "print(out_val.shape)\n",
    "print_params_dict()\n",
    "print_params_number()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention for RNN (use_mean_attention=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(128, 5, 32)\n",
      "attention_for_rnn/W:0 \t{'number': 25, 'shape': [5, 5]}\n",
      "attention_for_rnn/b:0 \t{'number': 5, 'shape': [5]}\n",
      "30\n"
     ]
    }
   ],
   "source": [
    "# Attention for RNN with use_mean_attention=True.\n",
    "# (np, tf, attention_for_rnn and the print_params_* helpers come from the\n",
    "# import cell at the top of the notebook; they are not re-imported here.)\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input shape: [128, 5, 32]\n",
    "x = tf.constant(np.arange(20480, dtype=np.float32).reshape([128, 5, 32]))\n",
    "# Output shape: [128, 5, 32]; n_step matches axis 1 of the input.\n",
    "o = attention_for_rnn(x, n_step=5, use_mean_attention=True)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    o_val = sess.run(o)\n",
    "\n",
    "print(o_val.shape)\n",
    "print_params_dict()\n",
    "print_params_number()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use attention after lstm (use tensorlayer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[TL] InputLayer  input: (128, 16, 5)\n",
      "[TL] RNNLayer rnn: n_hidden: 32 n_steps: 5 in_dim: 3 in_shape: (128, 16, 5) cell_fn: LSTMCell \n",
      "[TL]        RNN batch_size (concurrent processes): 128\n",
      "[TL]      n_params : 2\n",
      "[TL] InputLayer  input: (128, 5, 32)\n",
      "[TL] FlattenLayer flatten: 160\n",
      "[TL] DenseLayer  dense: 1 sigmoid\n",
      "(128, 1)\n",
      "rnn/lstm_cell/kernel:0 \t{'number': 4736, 'shape': [37, 128]}\n",
      "rnn/lstm_cell/bias:0 \t{'number': 128, 'shape': [128]}\n",
      "attention_for_rnn/W:0 \t{'number': 25, 'shape': [5, 5]}\n",
      "attention_for_rnn/b:0 \t{'number': 5, 'shape': [5]}\n",
      "dense/W:0 \t{'number': 160, 'shape': [160, 1]}\n",
      "dense/b:0 \t{'number': 1, 'shape': [1]}\n",
      "5055\n"
     ]
    }
   ],
   "source": [
    "# Attention applied to the per-step outputs of an LSTM (TensorLayer wrappers).\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input shape: [128, 16, 5] = [batch, n_steps, n_features]\n",
    "x = tf.constant(np.arange(10240, dtype=np.float32).reshape([128, 16, 5]))\n",
    "net = InputLayer(x)\n",
    "\n",
    "# Use attention after lstm.\n",
    "# n_steps is passed explicitly: with the RNNLayer default (5) only the first\n",
    "# 5 of the 16 time steps are fed to the cell and the rest is silently dropped.\n",
    "net = RNNLayer(net, tf.nn.rnn_cell.LSTMCell, n_hidden=32, n_steps=16)\n",
    "attended = attention_for_rnn(net.outputs)\n",
    "\n",
    "# Wrap the plain tensor again so the TensorLayer layers can consume it.\n",
    "net = InputLayer(attended)\n",
    "net = FlattenLayer(net)\n",
    "net = DenseLayer(net, n_units=1, act=tf.nn.sigmoid)\n",
    "o = net.outputs\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    o_val = sess.run(o)\n",
    "\n",
    "print(o_val.shape)  # [128, 1]\n",
    "print_params_dict()\n",
    "print_params_number()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use attention before lstm (use tensorlayer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[TL] InputLayer  input: (128, 16, 5)\n",
      "[TL] RNNLayer rnn: n_hidden: 32 n_steps: 5 in_dim: 3 in_shape: (128, 16, 5) cell_fn: LSTMCell \n",
      "[TL]        RNN batch_size (concurrent processes): 128\n",
      "[TL]      n_params : 2\n",
      "[TL] DenseLayer  dense: 1 sigmoid\n",
      "(128, 1)\n",
      "attention_for_rnn/W:0 \t{'number': 256, 'shape': [16, 16]}\n",
      "attention_for_rnn/b:0 \t{'number': 16, 'shape': [16]}\n",
      "rnn/lstm_cell/kernel:0 \t{'number': 4736, 'shape': [37, 128]}\n",
      "rnn/lstm_cell/bias:0 \t{'number': 128, 'shape': [128]}\n",
      "dense/W:0 \t{'number': 32, 'shape': [32, 1]}\n",
      "dense/b:0 \t{'number': 1, 'shape': [1]}\n",
      "5169\n"
     ]
    }
   ],
   "source": [
    "# Attention applied to the input sequence before it enters an LSTM.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input shape: [128, 16, 5] = [batch, n_steps, n_features]\n",
    "x = tf.constant(np.arange(10240, dtype=np.float32).reshape([128, 16, 5]))\n",
    "\n",
    "# Use attention before lstm\n",
    "attended = attention_for_rnn(x)\n",
    "net = InputLayer(attended)\n",
    "# n_steps is passed explicitly: with the RNNLayer default (5) only the first\n",
    "# 5 of the 16 attended time steps would reach the cell.\n",
    "# return_last=True keeps only the final step, so no flattening is needed.\n",
    "net = RNNLayer(net, tf.nn.rnn_cell.LSTMCell, n_hidden=32, n_steps=16, return_last=True)\n",
    "\n",
    "net = DenseLayer(net, n_units=1, act=tf.nn.sigmoid)\n",
    "o = net.outputs\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    o_val = sess.run(o)\n",
    "\n",
    "print(o_val.shape)  # [128, 1]\n",
    "print_params_dict()\n",
    "print_params_number()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention Flow Self Match Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(128, 10, 128)\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "# Attention Flow Self Match Layer\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input shapes: c is [128, 10, 32], q is [128, 5, 32]\n",
    "c = tf.constant(np.arange(40960, dtype=np.float32).reshape([128, 10, 32]))\n",
    "q = tf.constant(np.arange(20480, dtype=np.float32).reshape([128, 5, 32]))\n",
    "\n",
    "o = attention_flow_self(c, q)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    o_val = sess.run(o)\n",
    "\n",
    "print(o_val.shape)  # (128, 10, 128)\n",
    "print_params_dict()\n",
    "print_params_number()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage of `tf.einsum`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 3, 4, 1)\n",
      "True\n",
      "W:0 \t{'number': 30, 'shape': [30, 1]}\n",
      "b:0 \t{'number': 1, 'shape': [1]}\n",
      "31\n"
     ]
    }
   ],
   "source": [
    "# tf.einsum vs. reshape + matmul: contract the last axis of a rank-4 tensor.\n",
    "tf.reset_default_graph()\n",
    "N, T, J, d = 10, 3, 4, 10\n",
    "# h shape: [N, T, J, 3*d]\n",
    "h = tf.constant(np.arange(N * T * J * 3 * d, dtype=np.float32).reshape([N, T, J, 3 * d]))\n",
    "\n",
    "# w shape: [3*d, 1]; the bias returned by get_wb is not used in this comparison.\n",
    "w, b = get_wb([3 * d, 1])\n",
    "\n",
    "# Multiply the last axis of h by w in a single einsum call.\n",
    "o = tf.einsum(\"ntjd,do->ntjo\", h, w)\n",
    "# Reference: flatten the leading axes, matmul, then restore them using T and J\n",
    "# rather than hard-coded sizes.\n",
    "o2 = tf.matmul(tf.reshape(h, [-1, 3 * d]), w)\n",
    "o2 = tf.reshape(o2, [-1, T, J, 1])\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    o_val, o2_val = sess.run([o, o2])\n",
    "\n",
    "print(o_val.shape)  # (N, T, J, 1) = (10, 3, 4, 1)\n",
    "print((o_val == o2_val).all())\n",
    "print_params_dict()\n",
    "print_params_number()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention Flow Match Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(128, 10, 128)\n",
      "-----------\n",
      "params_dict\n",
      "  af/W:0 \t{'number': 96, 'shape': [96, 1]}\n",
      "-----------\n",
      "params_number: 96\n"
     ]
    }
   ],
   "source": [
    "# Attention Flow Match Layer\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input shapes: c is [128, 10, 32], q is [128, 5, 32]\n",
    "c = tf.constant(np.arange(40960, dtype=np.float32).reshape([128, 10, 32]))\n",
    "q = tf.constant(np.arange(20480, dtype=np.float32).reshape([128, 5, 32]))\n",
    "\n",
    "o = attention_flow(c, q, name=\"af\")\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    o_val = sess.run(o)\n",
    "\n",
    "print(o_val.shape)  # (128, 10, 128)\n",
    "print_params_dict()\n",
    "print_params_number()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
